from typing import List

from data import *
from julia import Main
import math
from utils import *
from sklearn.linear_model import Ridge


def full_obs_ate_multi_trials(X, y0, y1, params, nt):
    """Run one ATE estimator ``nt`` times and summarize its absolute error.

    Args:
        X: covariate matrix, one row per unit.
        y0, y1: potential outcomes under control / treatment, one entry per
            unit (assumed 1-D -- TODO confirm with callers).
        params: estimator configuration forwarded unchanged to
            ``full_obs_ate_est`` (``params["method"]`` selects the estimator).
        nt: number of independent trials (fresh random design each time);
            must be at least 1.

    Returns:
        ``(mean_err, std_err, mean_err / |tau|, std_err / |tau|)`` where the
        per-trial errors are ``|tauhat - tau|``, ``tau = mean(y1 - y0)`` is the
        true ATE, and ``std_err`` is the population (ddof=0) standard deviation
        of the errors.  The two relative values are inf/nan (with a NumPy
        runtime warning) when ``tau == 0``.

    Raises:
        ValueError: if ``nt < 1``.
    """
    if nt < 1:
        raise ValueError(f"nt must be at least 1, got {nt}")
    tauhat = np.zeros(nt)
    for i in range(nt):
        # Some estimators (e.g. the regression-adjusted ones) return a length-1
        # array; assigning a size-1 array with ndim > 0 into a scalar slot is
        # deprecated since NumPy 1.25, so unwrap it to a true scalar first.
        tauhat[i] = np.asarray(full_obs_ate_est(X, y0, y1, params)).item()
    tau = np.mean(y1 - y0)
    errs = np.abs(tauhat - tau)
    mean_err = np.mean(errs)
    std_err = np.std(errs)
    return mean_err, std_err, mean_err / abs(tau), std_err / abs(tau)


def full_obs_ate_est(X, y0, y1, params):
    """Draw one random design and return one ATE estimate.

    ``params["method"]`` selects the estimator:

    * ``'ht_uniform'``      -- Horvitz-Thompson under independent fair +/-1 assignment
    * ``'GSW'``             -- ``GSW_full_obs_ate``
    * ``'rand_vec'``        -- ``rand_vec_reg_adj_cross``
    * ``'classic_reg_adj'`` -- ``classic_reg_adj``
    * ``'simple_lev'``      -- ``simple_lev_approach``
    * ``'four_vecs'``       -- ``four_vecs``

    Args:
        X: covariate matrix, one row per unit (not used by ``'ht_uniform'``).
        y0, y1: potential outcomes under control / treatment; both are
            supplied and the sampled design decides which one each unit reveals.
        params: dict containing at least ``"method"``; forwarded unchanged to
            the selected estimator.

    Returns:
        The estimate produced by the selected method.

    Raises:
        ValueError: if ``params["method"]`` is not one of the names above.
    """
    method = params["method"]
    if method == 'ht_uniform':
        # z = +1 -> treated, term is 2*y1; z = -1 -> control, term is -2*y0.
        z = np.random.choice([-1, 1], size=y0.shape[0], p=[0.5, 0.5])
        y = (y1 - y0) + np.multiply(z, y0 + y1)
        return sum(y) / y0.shape[0]
    if method == 'GSW':
        return GSW_full_obs_ate(X, y0, y1, params)
    if method == 'rand_vec':
        return rand_vec_reg_adj_cross(X, y0, y1, params)
    if method == 'classic_reg_adj':
        return classic_reg_adj(X, y0, y1, params)
    if method == 'simple_lev':
        return simple_lev_approach(X, y0, y1, params)
    if method == 'four_vecs':
        return four_vecs(X, y0, y1, params)
    # Previously fell through and returned None, which only failed later
    # (and far from the cause) inside the caller.
    raise ValueError(f"unknown ATE estimation method: {method!r}")


def GSW_full_obs_ate(X, y0, y1, params):
    """ATE estimate under the Gram-Schmidt-Walk design (local Julia ``GSWDesign`` module).

    Samples one assignment with ``GSWDesign.sample_gs_walk`` using the
    robustness/balance parameter fixed at 0.5, then returns the
    Horvitz-Thompson estimate with treatment probability 0.5.

    Args:
        X: covariate matrix, one row per unit; copied into ``GSWDesign.X``.
        y0, y1: potential outcomes under control / treatment.
        params: unused here; kept for a uniform estimator signature.

    Returns:
        The Horvitz-Thompson estimate of the ATE.

    Side effects:
        Ensures ``"./"`` is on Julia's ``LOAD_PATH`` and sets the Julia
        globals ``GSWDesign.X`` and ``GSWDesign.lamda``.
    """
    # Register "./" only once: an unconditional push! appends a duplicate
    # LOAD_PATH entry on every call, i.e. once per trial.
    Main.eval('"./" in LOAD_PATH || push!(LOAD_PATH, "./")')
    from julia import GSWDesign
    n, _ = X.shape
    GSWDesign.X = np.array(X)
    GSWDesign.lamda = 0.5
    z = GSWDesign.sample_gs_walk(GSWDesign.X, GSWDesign.lamda)
    # The HT weight is the marginal treatment probability, kept separate from
    # the design's balance parameter so that changing ``lamda`` cannot silently
    # rescale the estimate.
    # NOTE(review): ``1 - z`` assumes z is coded 0/1 (not +/-1), and the weight
    # assumes each unit is treated with probability 1/2 -- confirm both
    # against the local GSWDesign module.
    treat_prob = 0.5
    tauhat = (sum(np.multiply(z, y1)) - sum(np.multiply(1 - z, y0))) / treat_prob / n
    return tauhat


def rand_vec_reg_adj_cross(X, y0, y1, params):
    """Cross-fitted ridge regression adjustment of the HT estimator (uniform design).

    Two independent fair +/-1 vectors are drawn: the first assigns treatment
    (+1 treated, -1 control), the second splits units into two halves
    (+1 = S, -1 = Sbar).  One ridge model (no intercept,
    alpha = log(n) * (max row norm of X)**2) is fit per half on doubled
    covariates and doubled observed outcomes.  Each unit's outcome is then
    reduced by the prediction of the model trained on the *other* half, and the
    Horvitz-Thompson contrast (weight 2/n) is returned.

    Args:
        X: covariate matrix, one row per unit.
        y0, y1: potential outcomes under control / treatment (assumed 1-D;
            they are expanded to columns below -- TODO confirm with callers).
        params: unused here; kept for a uniform estimator signature.

    Returns:
        The estimate -- typically a length-1 ndarray, because the built-in
        ``sum`` is applied to (k, 1) arrays.
    """
    _num_rows, _num_cols = X.shape  # unpacked only to require a 2-D X
    row_norm_max = max(np.linalg.norm(X, axis=1))
    y0_col = np.expand_dims(y0, axis=1)
    y1_col = np.expand_dims(y1, axis=1)
    n_units = y0_col.shape[0]
    ridge_alpha = np.log(n_units) * row_norm_max * row_norm_max

    # Draw order (treatment first, then split) fixes which part of the
    # np.random stream each vector consumes.
    treat = np.random.choice([-1, 1], size=(n_units, 1), p=[0.5, 0.5])  # +1 treated, -1 control
    half = np.random.choice([-1, 1], size=(n_units, 1), p=[0.5, 0.5])  # +1 = S, -1 = Sbar
    # Observed outcome: y1 where treat == +1, y0 where treat == -1.
    observed = ((y0_col + y1_col) - np.multiply(treat, y0_col - y1_col)) / 2

    # One ridge model per half; covariates and targets are both doubled.
    model_sbar = Ridge(alpha=ridge_alpha, fit_intercept=False)
    model_sbar.fit(2 * X[np.squeeze(half == -1), :], 2 * observed[half == -1])
    model_s = Ridge(alpha=ridge_alpha, fit_intercept=False)
    model_s.fit(2 * X[np.squeeze(half == 1), :], 2 * observed[half == 1])
    pred_from_s = np.expand_dims(model_s.predict(X), axis=1)
    pred_from_sbar = np.expand_dims(model_sbar.predict(X), axis=1)

    def cell(arm, side):
        """Row mask for units with ``treat == arm`` and ``half == side``."""
        return np.squeeze(np.logical_and(treat == arm, half == side))

    # Cross-fitting: Sbar units use the S model's prediction and vice versa.
    estimate = (sum((y1_col - pred_from_s)[cell(1, -1)])
                + sum((y1_col - pred_from_sbar)[cell(1, 1)]))
    estimate = (estimate
                - sum((y0_col - pred_from_s)[cell(-1, -1)])
                - sum((y0_col - pred_from_sbar)[cell(-1, 1)]))
    return estimate * 2 / n_units


def classic_reg_adj(X, y0, y1, params):
    """Freedman's multiple-regression ATE estimate under a uniform +/-1 design.

    See the multiple regression estimator in
    Freedman, David A. "On regression adjustments to experimental data."
    Advances in Applied Mathematics 40, no. 2 (2008): 180-193.

    A fair +/-1 vector assigns treatment (+1 treated, -1 control).  The
    observed outcome is regressed by ordinary least squares on
    ``[1, treatment indicator, X]``, and the coefficient of the treatment
    indicator is returned.

    Args:
        X: covariate matrix, one row per unit.
        y0, y1: potential outcomes under control / treatment (assumed 1-D;
            they are expanded to columns below -- TODO confirm with callers).
        params: unused here; kept for a uniform estimator signature.

    Returns:
        Length-1 ndarray holding the treatment-indicator coefficient.
    """
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)
    n_units = y0.shape[0]
    z = np.random.choice([-1, 1], size=(n_units, 1), p=[0.5, 0.5])
    # Observed outcome: y1 where z == +1, y0 where z == -1.
    y = ((y0 + y1) - np.multiply(z, y0 - y1)) / 2
    treated = (z + 1) / 2  # map {-1, +1} -> {0, 1}
    design = np.concatenate((np.ones((n_units, 1)), treated, X), axis=1)
    # Plain least squares instead of Ridge(alpha=0), which scikit-learn advises
    # against for numerical reasons.  lstsq returns the minimum-norm solution
    # when the design is rank-deficient (e.g. X already has a constant column);
    # the treatment coefficient is still uniquely determined in that case.
    coef, _, _, _ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1]


def simple_lev_approach(X, y0, y1, params):
    """Ridge outcome-imputation ATE estimate under a uniform +/-1 design.

    A single fair +/-1 vector assigns treatment (+1 treated, -1 control).  One
    ridge model (no intercept, alpha = log(n) * (max row norm of X)**2) is fit
    on the control units' observed outcomes and another on the treated units'.
    The estimate is the average, over all units, of the treated-model
    prediction minus the control-model prediction.

    Args:
        X: covariate matrix, one row per unit.
        y0, y1: potential outcomes under control / treatment (assumed 1-D;
            they are expanded to columns below -- TODO confirm with callers).
        params: unused here; kept for a uniform estimator signature.

    Returns:
        The estimate -- typically a length-1 ndarray, because the built-in
        ``sum`` is applied to an (n, 1) array.
    """
    _num_rows, _num_cols = X.shape  # unpacked only to require a 2-D X
    row_norm_max = max(np.linalg.norm(X, axis=1))
    y0_col = np.expand_dims(y0, axis=1)
    y1_col = np.expand_dims(y1, axis=1)
    n_units = y0_col.shape[0]
    ridge_alpha = np.log(n_units) * row_norm_max * row_norm_max

    arm = np.random.choice([-1, 1], size=(n_units, 1), p=[0.5, 0.5])  # +1 treated, -1 control
    # Observed outcome: y1 where arm == +1, y0 where arm == -1.
    observed = ((y0_col + y1_col) - np.multiply(arm, y0_col - y1_col)) / 2

    def imputed(label):
        """Fit ridge on units with ``arm == label``; predict every unit, as a column."""
        model = Ridge(alpha=ridge_alpha, fit_intercept=False)
        model.fit(X[np.squeeze(arm == label), :], observed[arm == label])
        return np.expand_dims(model.predict(X), axis=1)

    control_hat = imputed(-1)
    treated_hat = imputed(1)
    return sum(treated_hat - control_hat) / n_units


def four_vecs(X, y0, y1, params):
    """Cross-fitted HT ATE estimate adjusted with four arm-by-half ridge fits.

    Two independent fair +/-1 vectors are drawn: ``z1`` assigns treatment
    (+1 treated, -1 control) and ``z2`` splits units into two halves
    (+1 = S, -1 = Sbar).  Within each half, one ridge model (no intercept,
    alpha = log(n) * (max row norm of X)**2) is fit on the treated units and
    one on the control units.  Each unit's observed outcome is reduced by the
    *average* of the two arm predictions from the other half, and the
    Horvitz-Thompson contrast (weight 2/n) is returned.

    Why the average: unit by unit, the adjusted HT estimate equals
    ``(1/n) * [(y1 - y0) + z1 * (y0 + y1 - 2h)]``, so its variance is driven
    by ``y0 + y1 - 2h``.  The ideal adjustment ``h`` is therefore
    ``(y0 + y1) / 2`` -- the mean of the treated- and control-arm predictions.
    Subtracting their *difference* (an estimate of ``y1 - y0``) keeps the
    estimate unbiased but leaves that variance term in place even when both
    models fit perfectly.

    Args:
        X: covariate matrix, one row per unit.
        y0, y1: potential outcomes under control / treatment (assumed 1-D;
            they are expanded to columns below -- TODO confirm with callers).
        params: unused here; kept for a uniform estimator signature.

    Returns:
        The estimate -- typically a length-1 ndarray, because the built-in
        ``sum`` is applied to (k, 1) arrays.  Every one of the four
        arm-by-half groups must be non-empty for the ridge fits to succeed.
    """
    n, _ = X.shape
    zeta = max(np.linalg.norm(X, axis=1))
    y0 = np.expand_dims(y0, axis=1)
    y1 = np.expand_dims(y1, axis=1)
    lamda = np.log(y0.shape[0]) * zeta * zeta
    z1 = np.random.choice([-1, 1], size=(y0.shape[0], 1), p=[0.5, 0.5])  # +1 treated, -1 control
    z2 = np.random.choice([-1, 1], size=(y0.shape[0], 1), p=[0.5, 0.5])  # +1 = S, -1 = Sbar
    # Observed outcome: y1 where z1 == +1, y0 where z1 == -1.
    y = ((y0 + y1) - np.multiply(z1, y0 - y1)) / 2

    treated_sbar = np.squeeze(np.logical_and(z1 == 1, z2 == -1))
    treated_s = np.squeeze(np.logical_and(z1 == 1, z2 == 1))
    control_sbar = np.squeeze(np.logical_and(z1 == -1, z2 == -1))
    control_s = np.squeeze(np.logical_and(z1 == -1, z2 == 1))

    def _fit_predict(mask):
        """Fit ridge on the rows selected by ``mask``; return (n, 1) predictions for all units."""
        model = Ridge(alpha=lamda, fit_intercept=False)
        model.fit(X[mask, :], y[mask])
        return model.predict(X)

    # Mean-outcome prediction from each half (average of its two arm models).
    adj_from_s = (_fit_predict(treated_s) + _fit_predict(control_s)) / 2
    adj_from_sbar = (_fit_predict(treated_sbar) + _fit_predict(control_sbar)) / 2

    # Cross-fitting: Sbar units are adjusted with the S models and vice versa.
    resid_sbar = y - adj_from_s
    resid_s = y - adj_from_sbar
    tauhat = sum(resid_sbar[treated_sbar]) + sum(resid_s[treated_s])
    tauhat = tauhat - sum(resid_sbar[control_sbar]) - sum(resid_s[control_s])
    return tauhat * 2 / n


